"""Convolutional NIN models: stacked conv blocks followed by global average pooling."""
import numpy as np
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.fusion import *

# Channel count of a conv block at width multiplier 1; the models below
# scale it as int(BASE_WIDTH * width).
BASE_WIDTH = 128


class NIN(nn.Module):
    """Stack of ``depth`` conv blocks followed by a 1x1 conv classifier.

    Each block is three convolutions (a strided 3x3, then two 1x1), each
    optionally followed by BatchNorm2d and always by ReLU, with an optional
    trailing Dropout.  ``self.model`` is an ``nn.Sequential`` whose entries
    ``0 .. depth-1`` are the blocks and whose entry ``depth`` is the
    classifier conv, so state-dict keys look like ``model.<block>.<layer>.*``
    for block layers and ``model.<depth>.*`` for the classifier.

    Args:
        n_channel: number of input channels.
        n_classes: number of output channels of the classifier conv.
        depth: number of conv blocks.
        width: multiplier applied to ``BASE_WIDTH`` to get the block width.
        batch_norm: insert BatchNorm2d after every block convolution.
        dropout: append Dropout at the end of every block.
    """

    # eps passed to the fusion routine; nn.BatchNorm2d is constructed with
    # its default eps in _nin_block, which is this value.
    _BN_EPS = 1e-5

    def __init__(
        self, n_channel, n_classes, depth=2, width=1, batch_norm=False, dropout=False
    ):
        super(NIN, self).__init__()
        self.n_classes = n_classes
        self.depth = depth
        self.width = width
        self.batch_norm = batch_norm
        self.dropout = dropout

        # Kept in a local list on purpose: assigning a list to ``self.modules``
        # would shadow ``nn.Module.modules()`` and break every caller of it.
        blocks = []
        n_in, n_out = n_channel, int(BASE_WIDTH * width)
        for _ in range(depth):
            blocks.append(self._nin_block(n_in, n_out))
            n_in = n_out
        blocks.append(nn.Conv2d(n_out, n_classes, 1, stride=1, padding=0))
        self.model = nn.Sequential(*blocks)

    @property
    def last_layer_name(self):
        """Return the string ``"model.<depth>.7"``."""
        # NOTE(review): with the layout built in __init__, ``model.<depth>`` is
        # the classifier Conv2d itself and has no child named "7"; confirm what
        # callers expect before relying on this name.
        return "model.{}.7".format(self.depth)

    def _nin_block(self, n_in, n_out):
        """Build one block: 3x3 stride-2 conv, then two 1x1 stride-1 convs."""
        layers = []
        for i in range(3):
            k = 1 if i > 0 else 3
            pad = 0 if i > 0 else 1
            strides = 1 if i > 0 else 2
            layers.append(nn.Conv2d(n_in, n_out, k, stride=strides, padding=pad))
            n_in = n_out
            if self.batch_norm:
                layers.append(nn.BatchNorm2d(n_out))
            layers.append(nn.ReLU())
        if self.dropout:
            layers.append(nn.Dropout())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the model and average over the spatial dims.

        Returns a tensor reshaped to ``(batch, -1)``.
        """
        b, _, _, _ = x.shape
        x = self.model(x)
        # Pooling with a kernel equal to the full spatial size leaves 1x1 maps.
        return F.avg_pool2d(x, x.size()[2:]).reshape((b, -1))

    def get_last_layer_weights(self):
        """Return ``(weight, bias)`` of the classifier conv from the state dict."""
        state_dict = self.state_dict()
        n = self.depth
        return (
            state_dict['model.{}.weight'.format(n)],
            state_dict['model.{}.bias'.format(n)],
        )

    def combine_model(self, model_1, mode1_2, breakpoint=-1):
        """Load a mix of two models' weights into this model.

        Starts from ``mode1_2``'s state dict and overwrites every entry whose
        top-level index under ``model.`` is ``>= breakpoint`` with the entry
        from ``model_1``.  ``breakpoint=-1`` means ``self.depth``, i.e. only
        the classifier conv is taken from ``model_1``.

        Raises:
            ValueError: if no entry of ``model_1`` is at or after ``breakpoint``.
        """
        if breakpoint == -1:
            breakpoint = self.depth
        state_dict_1 = model_1.state_dict()
        new_state_dict = mode1_2.state_dict().copy()

        target_weight_name = [
            n for n in state_dict_1 if self.is_later_layers(n, breakpoint)
        ]
        if not target_weight_name:
            raise ValueError("No weight to be replace")
        for n in target_weight_name:
            new_state_dict[n] = state_dict_1[n]
        self.load_state_dict(new_state_dict)

    def is_later_layers(self, name, layer_idx):
        """Return True if state-dict key ``name`` belongs to index >= ``layer_idx``.

        The index is the first dotted component after ``"model."``.
        """
        splits = name.split("model.")
        group_idx = int(splits[1].split(".")[0])
        return group_idx >= layer_idx

    def get_intermediate(self, x):
        """Return ``{'block<i>': output}`` for every top-level entry of ``self.model``.

        The last entry is the output of the classifier conv (before pooling).
        """
        lookup = {}
        out = x
        for i, m in enumerate(self.model):
            out = m(out)
            lookup['block{}'.format(i)] = out
        return lookup

    def _fused_conv_params(self, state):
        """Fold each block conv's following BatchNorm into the conv.

        Reads ``state`` (a state dict of this model) without modifying it and
        returns a list of ``(block_idx, layer_idx, weight, bias)`` tuples, one
        per convolution that lives inside a block.
        """
        suffix = '.weight'
        fused = []
        for key in state:
            if not key.endswith(suffix):
                continue
            # Slice the suffix off explicitly; str.strip('.weight') would
            # strip a *set of characters* from both ends instead.
            prefix = key[:-len(suffix)]
            if prefix + '.running_mean' in state:
                continue  # this key is a BatchNorm weight, not a conv weight
            parts = key.split('.')
            if len(parts) == 3:
                continue  # classifier conv (model.<depth>.weight): no BN after it
            block_idx, layer_idx = (int(p) for p in parts[1:-1])
            bn = 'model.{}.{}'.format(block_idx, layer_idx + 1)
            w, b = fuse_conv_bn_weights(
                state[prefix + '.weight'],
                state[prefix + '.bias'],
                state[bn + '.running_mean'],
                state[bn + '.running_var'],
                self._BN_EPS,
                state[bn + '.weight'],
                state[bn + '.bias'],
            )
            fused.append((block_idx, layer_idx, w, b))
        return fused

    def fuse_batchnorm(self):
        """Fold every BatchNorm into its preceding conv, in place.

        The conv receives the fused weight and bias; the BatchNorm entries are
        reset to weight=1, bias=0, running_mean=0, running_var=1.
        """
        assert self.batch_norm, "This model does not use batchnorm"
        d = self.state_dict()
        for block_idx, layer_idx, w, b in self._fused_conv_params(d):
            conv = 'model.{}.{}'.format(block_idx, layer_idx)
            bn = 'model.{}.{}'.format(block_idx, layer_idx + 1)
            d[conv + '.weight'] = w
            d[conv + '.bias'] = b
            d[bn + '.weight'] = torch.ones_like(d[bn + '.weight'])
            d[bn + '.bias'] = torch.zeros_like(d[bn + '.bias'])
            d[bn + '.running_mean'] = torch.zeros_like(d[bn + '.running_mean'])
            d[bn + '.running_var'] = torch.ones_like(d[bn + '.running_var'])
        self.load_state_dict(d)

    def fuse_batchnorm_into_model(self, model_no_bn):
        """Write this model's BN-fused weights into ``model_no_bn``.

        ``model_no_bn`` is expected to have the same layout as this model but
        without BatchNorm layers, so conv ``i`` of a block here (indices
        0, 3, 6) maps to index ``i // 3 * 2`` there (0, 2, 4).  The classifier
        conv is copied unchanged.
        """
        assert self.batch_norm, "This model does not use batchnorm"
        d = self.state_dict()
        d_no_bn = model_no_bn.state_dict()
        for k, v in d.items():
            if len(k.split('.')) == 3:
                # Check the shape *before* overwriting, otherwise the check
                # compares the tensor with itself.
                assert d_no_bn[k].shape == v.shape
                d_no_bn[k] = v
        for block_idx, layer_idx, w, b in self._fused_conv_params(d):
            target = 'model.{}.{}'.format(block_idx, layer_idx // 3 * 2)
            assert d_no_bn[target + '.weight'].shape == w.shape
            d_no_bn[target + '.weight'] = w
            d_no_bn[target + '.bias'] = b
        model_no_bn.load_state_dict(d_no_bn)


class OldNIN(nn.Module):
    """Earlier NIN variant built as one flat ``nn.Sequential``.

    Layout: a first block with fixed 192/160/192 channels and max pooling,
    ``depth - 2`` middle blocks of width ``int(width * BASE_WIDTH)`` with
    average pooling, and a last block whose final 1x1 conv emits
    ``n_classes`` channels.  When ``batch_norm`` is False every BatchNorm
    slot holds ``nn.Identity`` so layer indices are the same either way.

    Args:
        n_channel: number of input channels.
        n_classes: number of output channels of the final conv.
        depth: total number of blocks (first + middle + last).
        width: multiplier applied to ``BASE_WIDTH`` for middle/last blocks.
        batch_norm: use BatchNorm2d after every convolution.
        dropout: append Dropout to the first and middle blocks.
    """

    def __init__(
        self, n_channel, n_classes, depth=2, width=1, batch_norm=False, dropout=False
    ):
        # Must name this class: super(NIN, self) raises TypeError here because
        # OldNIN does not inherit from NIN.
        super(OldNIN, self).__init__()
        self.n_classes = n_classes
        self.depth = depth
        self.width = width
        self.batch_norm = batch_norm
        self.dropout = dropout

        # Local list: a ``self.modules`` attribute would shadow
        # ``nn.Module.modules()``.
        layers = []
        layers += self._first_nin_block(n_channel)
        n_in = 192  # channels leaving the first block
        for _ in range(depth - 2):
            layers += self._nin_block(n_in)
            n_in = int(self.width * BASE_WIDTH)
        # Feed the last block whatever the previous block actually emits
        # (192 when there are no middle blocks).
        layers += self._last_nin_block(n_in)
        self.model = nn.Sequential(*layers)

    def _norm(self, channels):
        """Return BatchNorm2d(channels) if batch_norm is on, else Identity."""
        return nn.BatchNorm2d(channels) if self.batch_norm else nn.Identity()

    def _nin_block(self, n_in):
        """Middle block: 5x5 conv, two 1x1 convs, 3x3 stride-2 average pool."""
        width = int(self.width * BASE_WIDTH)
        layers = [
            nn.Conv2d(n_in, width, kernel_size=5, stride=1, padding=2),
            self._norm(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0),
            self._norm(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0),
            self._norm(width),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
        ]
        if self.dropout:
            layers.append(nn.Dropout())
        return layers

    def _first_nin_block(self, n_in):
        """First block: convs n_in->192->160->192, 3x3 stride-2 max pool."""
        base_layers = [
            nn.Conv2d(n_in, 192, kernel_size=5, stride=1, padding=2),
            self._norm(192),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
            self._norm(160),  # must match the 160 channels of the conv above
            nn.ReLU(inplace=True),
            nn.Conv2d(160, 192, kernel_size=1, stride=1, padding=0),
            self._norm(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        if self.dropout:
            base_layers.append(nn.Dropout())
        return base_layers

    def _last_nin_block(self, n_in):
        """Last block: 3x3 conv, 1x1 conv, 1x1 conv down to n_classes channels."""
        width = int(self.width * BASE_WIDTH)
        base_layers = [
            nn.Conv2d(n_in, width, kernel_size=3, stride=1, padding=1),
            self._norm(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0),
            self._norm(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, self.n_classes, kernel_size=1, stride=1, padding=0),
            self._norm(self.n_classes),  # the conv above emits n_classes channels
        ]
        return base_layers

    def forward(self, x):
        """Run the model and average over the spatial dims.

        Returns a tensor reshaped to ``(batch, -1)``.
        """
        b, _, _, _ = x.shape
        x = self.model(x)
        return F.avg_pool2d(x, x.size()[2:]).reshape((b, -1))